{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bernoulli Distribution\n",
    "\n",
    "The Bernoulli distribution is a special case of the binomial distribution with n = 1, i.e., a single\n",
    "success-or-failure trial is performed.\n",
    "\n",
    "Let us reconsider the familiar dataset of photos described in the Binomial distribution notebook, in which 20% of the photos are celebrity faces. The probability of success (picking a celebrity photo) in a single trial is 0.2 (`p = 0.2`). The probability of failure (not picking a celebrity photo) is `1 - p = 0.8`.\n",
    "\n",
    "Formally, $$P(X=1) = p$$ $$P(X=0) = 1 - p$$ where 1 represents success and 0 represents failure.\n",
    "\n",
    "This is demonstrated by the PyTorch code below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.distributions import Bernoulli\n",
    "\n",
    "# Success probability of a single trial (20% of the photos are celebrity faces)\n",
    "p = torch.tensor([0.2], dtype=torch.float32)\n",
    "\n",
    "# Instantiate the Bernoulli distribution parameterized by p\n",
    "bern_dist = Bernoulli(probs=p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Log Prob: -1.609\n",
      "Raw eval Log Prob: -1.609\n"
     ]
    }
   ],
   "source": [
    "# Single-observation test input: one successful trial (X = 1)\n",
    "X = torch.tensor([1.0], dtype=torch.float32)\n",
    "\n",
    "# Function to evaluate log prob using math formula\n",
    "def raw_eval(X, p):\n",
    "    \"\"\"Evaluate the Bernoulli log-probability directly from its definition.\n",
    "\n",
    "    Uses P(X=1) = p and P(X=0) = 1 - p. The branch is selected element-wise,\n",
    "    so X may hold a single outcome or a whole batch of outcomes; X and p are\n",
    "    broadcast against each other.\n",
    "\n",
    "    Args:\n",
    "        X: observed outcome(s), tensor or Python number. Entries equal to 1\n",
    "            count as success; every other value is treated as failure.\n",
    "        p: probability of success, tensor or Python number.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of log-probabilities with the broadcast shape of X and p.\n",
    "    \"\"\"\n",
    "    p = torch.as_tensor(p)\n",
    "    # Keep X on the same device as p so the element-wise select below is valid\n",
    "    X = torch.as_tensor(X, device=p.device)\n",
    "    # torch.where avoids the implicit bool(tensor) of `p if X == 1 else 1-p`,\n",
    "    # which raises for any X with more than one element\n",
    "    prob = torch.where(X == 1, p, 1 - p)\n",
    "    return torch.log(prob)\n",
    "\n",
    "# Compute the log-probability two ways: via the PyTorch distribution object\n",
    "# and via the hand-written formula\n",
    "log_prob = bern_dist.log_prob(X)\n",
    "raw_eval_log_prob = raw_eval(X, p)\n",
    "\n",
    "# Report both values\n",
    "print(\"Log Prob: {:.3f}\".format(log_prob[0]))\n",
    "print(\"Raw eval Log Prob: {:.3f}\".format(raw_eval_log_prob[0]))\n",
    "\n",
    "# The library result and the formula must agree within a small tolerance\n",
    "assert torch.isclose(log_prob, raw_eval_log_prob, atol=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the RNG seed so the drawn samples, and every statistic computed from\n",
    "# them, are identical on each Restart & Run All\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Number of samples to draw\n",
    "num_samples = 100000\n",
    "\n",
    "# Draw num_samples independent Bernoulli(p) outcomes (each is 0. or 1.)\n",
    "samples = bern_dist.sample([num_samples])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample Mean: 0.2002899944782257\n",
      "Dist Mean: 0.200\n"
     ]
    }
   ],
   "source": [
    "# The mean obtained from the samples\n",
    "sample_mean = samples.mean()\n",
    "print(\"Sample Mean: {}\".format(sample_mean))\n",
    "\n",
    "# The mean of the distribution from PyTorch (equals p for a Bernoulli)\n",
    "dist_mean = bern_dist.mean\n",
    "print(\"Dist Mean: {:.3f}\".format(dist_mean[0]))\n",
    "\n",
    "# As expected, the two means approximately match. The tolerance must be\n",
    "# small relative to the mean itself (0.2) for this check to be meaningful;\n",
    "# with ~1e5 samples the standard error of the sample mean is roughly 1e-3,\n",
    "# so 1e-2 leaves ample headroom.\n",
    "assert torch.isclose(sample_mean, dist_mean, atol=1e-2)\n",
    "\n",
    "# The variance obtained from the samples (the same draws used for the\n",
    "# sample mean, rather than a fresh, unrelated batch)\n",
    "sample_var = samples.var()\n",
    "\n",
    "# The variance of the distribution from PyTorch, p * (1 - p) for a Bernoulli\n",
    "dist_var = bern_dist.variance\n",
    "\n",
    "# As expected, the two variances approximately match. The tolerance is kept\n",
    "# well below the true variance (0.16) so that the check can actually fail.\n",
    "assert torch.isclose(sample_var, dist_var, atol=1e-2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
